import numpy as np
import torch
from params import get_args
from env.env import JSP_Env
from model.REINFORCE import REINFORCE
import time
import os

def test():
    """Run the greedy policy on every instance in ``args.test_dir``.

    For each directory entry the instance is loaded into the environment,
    the policy is rolled out with ``greedy=True`` until the environment
    reports ``done``, and one line holding the instance path, the resulting
    makespan and the elapsed time is appended to
    ``./result/<args.date>/test_result.txt``.

    Relies on the module-level globals ``args``, ``env`` and ``policy``
    that the ``__main__`` block sets up; the result directory must already
    exist.
    """
    result_path = "./result/{}/test_result.txt".format(args.date)

    # os.listdir returns entries in arbitrary order; sort so the report is
    # written in the same order on every run.
    for instance in sorted(os.listdir(args.test_dir)):
        file = os.path.join(args.test_dir, instance)
        print(f'{file}')

        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by system clock adjustments during the rollout.
        st = time.perf_counter()
        avai_op = env.load_instance(file)

        done = False
        while not done:
            data = env.get_graph_data()
            # The policy returns an index into the currently available
            # operations; the remaining outputs are not needed here.
            action_idx, _, _, _ = policy(avai_op, data, greedy=True)
            avai_op, _, done = env.step(avai_op[action_idx])

        # Stop the clock before querying the makespan and doing file I/O.
        ed = time.perf_counter()
        ms = env.get_makespan()

        with open(result_path, "a") as outfile:
            outfile.write(
                f'instance : {file:60}, policy : {ms:10}\t'
                f'time : {ed - st:10}\n'
            )

if __name__ == '__main__':
    # Script entry point: build the environment and policy from the parsed
    # arguments, restore the saved weights, then evaluate without gradients.
    # `args`, `env` and `policy` stay module-level because test() reads them.
    args = get_args()
    print(args)

    # test() appends its report inside this directory, so create it first.
    result_dir = './result/{}'.format(args.date)
    os.makedirs(result_dir, exist_ok=True)

    env = JSP_Env(args)
    policy = REINFORCE(args).to(args.device)

    # Non-strict load: the second positional argument of load_state_dict is
    # `strict`, spelled out here as a keyword.
    weights = torch.load(args.load_weight, map_location=args.device)
    policy.load_state_dict(weights, strict=False)

    # NOTE(review): the policy is never switched to eval() before testing;
    # confirm whether the model contains dropout / batch-norm layers.
    with torch.no_grad():
        test()